[Operator] Add Conv3d forward function #412
base: master
Conversation
Impressive performance! We will review soon.
Developing conv3d is quite hard. You could refer to the implementation of conv2d and figure out a better approach. If you are still confused, feel free to contact us in the community via WeChat.
- [16, 32, 120, 12, 12, 24, 3, 3, 3, 2, 1, 1]
- [16, 32, 240, 24, 24, 24, 3, 3, 3, 1, 1, 2]
- [16, 32, 24, 24, 24, 24, 3, 3, 3, 2, 2, 2]
- [16, 32, 24, 24, 24, 24, 3, 3, 3, 1, 2, 2]
Usually we suggest setting 5 shapes for core mode.
for k_w in [32 * i for i in range(1, 4)]
for stride in [1, (2, 2, 2), (3, 3, 3)]
for padding in [0, (1, 1, 1), (0, 1, 2)]
for groups in [1, 2, 4, 8]
Personally, I think the number of test cases is too large. Pick 20 shapes from classic networks; that performance data is convincing enough.
bench = Conv3dBenchmark(
    input_fn=conv3d_input_fn,
    op_name="conv3d",
    torch_op=torch.nn.functional.conv3d,
Since your implementation of conv3d is not registered into the aten library, the function called in the benchmark is still the torch op. Please add the registration in src/flag_gems/__init__.py and update the benchmark results.
@@ -264,7 +264,7 @@ def set_more_shapes(self):

@pytest.mark.conv2d
def test_perf_conv2d():
    def conv2d_input_fn(shape, dtype, device):
    def conv3d_input_fn(shape, dtype, device):
It's not time to enable the benchmark of conv2d right now, and the same goes for conv3d. I'll mark it as skip.
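For reference, a minimal sketch of marking such a perf test as skipped with pytest (the function body and reason string here are illustrative, not the repo's actual code):

```python
import pytest

# Hypothetical sketch: keep the conv3d perf test out of the run until the
# benchmark is ready; the reason string shows up in pytest's skip report.
@pytest.mark.skip(reason="conv2d/conv3d benchmarks are not enabled yet")
def test_perf_conv3d():
    ...
```

Running pytest with `-rs` will list the test as skipped along with the reason.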
shape_input, dtype=dtype, device=flag_gems.device, requires_grad=True
)
ref_inp = to_reference(inp, True)
torch.backends.cudnn.allow_tf32 = True
Usually we set allow_tf32 to False, since the precision of tf32 is not satisfying.
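A sketch of the suggested configuration, using PyTorch's standard backend flags (whether the matmul flag is also needed depends on the test's code paths):

```python
import torch

# Disable TF32 so fp32 references are computed in true float32 precision;
# TF32 keeps only ~10 mantissa bits, which can mask real accuracy bugs.
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False
```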
stride=strides,
padding=paddings,
dilation=dilations,
).to(dtype)
gems_assert_close will cast the reference tensor to dtype, so you don't need to do this again.
for t in range(T):
    for r in range(R):
        for s in range(S):
            for c in range(C_per_group):
Using multiple layers of loops might not be a good idea for computing convolution. Try loading tensors with high-dimensional indices and using the tl.dot primitive.
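The suggestion above is essentially the implicit-GEMM formulation: gather input patches via high-dimensional indexing and reduce them with one matrix multiply (tl.dot in a Triton kernel). A NumPy sketch of the same idea, assuming a single batch, no padding, groups=1, and dilation=1 (names and shapes are illustrative, not the PR's code):

```python
import numpy as np

def conv3d_im2col(x, w, stride=1):
    # x: (C, D, H, W) input; w: (K, C, T, R, S) weights.
    C, D, H, W = x.shape
    K, _, T, R, S = w.shape
    Do = (D - T) // stride + 1
    Ho = (H - R) // stride + 1
    Wo = (W - S) // stride + 1
    # Gather each (C, T, R, S) patch into one column. In a Triton kernel this
    # gather is done with broadcasted tl.arange index tensors, not Python loops.
    cols = np.empty((C * T * R * S, Do * Ho * Wo), dtype=x.dtype)
    idx = 0
    for d in range(Do):
        for h in range(Ho):
            for wi in range(Wo):
                patch = x[:, d*stride:d*stride+T,
                             h*stride:h*stride+R,
                             wi*stride:wi*stride+S]
                cols[:, idx] = patch.ravel()
                idx += 1
    # One matmul replaces the nested c/t/r/s reduction loops (tl.dot in Triton).
    out = w.reshape(K, -1) @ cols
    return out.reshape(K, Do, Ho, Wo)
```

The payoff is that the innermost reduction becomes a dense GEMM, which maps directly onto tensor-core-friendly primitives like tl.dot instead of scalar accumulation.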
PR Category
Operator
Type of Change
New Feature
Description
Add Conv3d forward function and related tests
Issue
Progress
Performance
Operator: conv3d Performance Test (dtype=torch.float16, mode=cuda, level=core)
Operator: conv3d Performance Test (dtype=torch.float32, mode=cuda, level=core)
Operator: conv3d Performance Test (dtype=torch.bfloat16, mode=cuda, level=core)